from nsfw_detector import predict
from tqdm import tqdm
import sys,os
import pickle


# Positional command-line arguments (all three are required).
img_list = sys.argv[1]  # imglist: text file with one image path per line
P = sys.argv[2]         # slurm id, used to name this job's output file
save_dir = sys.argv[3]  # save dir for the results pickle

# Start from an empty list so `imgs` is always defined, even when an empty
# string is passed as the image list (that case used to raise NameError at
# the print below).
imgs = []
if img_list != '':
    # The context manager closes the list file as soon as it has been read.
    with open(img_list, 'r') as f_list:
        # Strip surrounding whitespace/newlines and drop blank lines so no
        # empty "paths" are handed to the classifier.
        imgs = [_d.strip() for _d in f_list if _d.strip()]

# Report the file name that is actually written at the end of the script
# (dets{P}.pkl) rather than the previously advertised {P}.pkl.
pkl_path = os.path.join(save_dir, f'dets{P}.pkl')
print(f'done preparing dataset!, N={len(imgs)}, will save to {pkl_path}', flush=True)

# Load the model once and reuse it for every batch.
model = predict.load_model('nsfw_mobilenet2.224x224.h5')
dets = []  # one predict.classify() result per batch, in input order
BS = 50    # number of images handed to the classifier per call

# Ceiling division so a trailing partial batch gets its own iteration.
# The previous floor division plus a special case on the last iteration
# re-classified the final full batch together with the remainder, and
# classified nothing at all when there were fewer than BS images.
parts = (len(imgs) + BS - 1) // BS
for i in tqdm(range(parts)):
    # Slicing past the end of the list is safe: the last batch is just shorter.
    dets.append(predict.classify(model, imgs[i * BS:(i + 1) * BS]))

# Persist all batch results for this job; `with` guarantees the file is
# flushed and closed even if pickling fails part-way.
with open(os.path.join(save_dir, f'dets{P}.pkl'), 'wb') as f_out:
    pickle.dump(dets, f_out)


# Disabled scratch snippet (guarded by `if False`, so it never runs as part of
# this script). It merges the per-job result pickles and prints the columns
# with the highest 'porn' and 'hentai' values.
if False:
    import numpy as np
    import pickle
    import pandas as pd
    dets = []
    # NOTE(review): `xx` is not defined anywhere in the visible source; it looks
    # like a placeholder for the results directory and must be replaced by hand.
    save_dir = xx
    # Concatenate the contents of dets0.pkl ... dets59.pkl found in save_dir.
    for P in range(60):
        # NOTE(review): the file handle is never closed, and pickle.load should
        # only be used on files this pipeline produced itself.
        dets.extend(pickle.load(open(os.path.join(save_dir, f'dets{P}.pkl'),'rb')))
    # NOTE(review): the lookups below need 'porn' and 'hentai' to be row labels.
    # `dets` is a list here; if its items are per-batch dicts keyed by image
    # path, DataFrame(dets) would put batches on the rows instead — confirm
    # (merging the dicts into one dict first may be what was intended).
    df = pd.DataFrame(dets)
    # Row values sorted in descending order.
    psort = np.sort(df.loc['porn'].values)[::-1]
    hsort = np.sort(df.loc['hentai'].values)[::-1]
    # Print the columns whose value is strictly above the entry at index 100 of
    # the descending sort (i.e. at most 100 columns per category).
    print(df.columns[df.loc['porn'] > psort[100] ])
    print(df.columns[df.loc['hentai'] > hsort[100] ])
